{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]='1'\n",
    "# from alphagen.config import *\n",
    "# from alphagen.data.tokens import *\n",
    "from alphagen.models.alpha_pool import AlphaPoolBase, AlphaPool\n",
    "from alphagen.rl.env.core import AlphaEnvCore\n",
    "import torch.nn.functional as F\n",
    "from gan.dataset import Collector\n",
    "from gan.network.generater import NetG_DCGAN\n",
    "from gan.network.masker import NetM\n",
    "from gan.network.predictor import NetP, train_regression_model,train_regression_model_with_weight\n",
    "from alphagen.rl.env.wrapper import SIZE_ACTION,action2token\n",
    "\n",
    "from alphagen_generic.features import open_\n",
    "from gan.utils import Builders\n",
    "from alphagen_generic.features import *\n",
    "from alphagen.data.expression import *\n",
    "\n",
    "from gan.utils.data import get_data_by_year\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Run configuration read by the evaluation cell below.\n",
    "# `instruments`, `save_name` and `window` are interpolated into the out/ file paths;\n",
    "# `instruments` and `freq` are also passed to get_data_by_year.\n",
    "instruments: str = 'csi300'\n",
    "freq = \"day\"\n",
    "save_name = \"test\"\n",
    "# float('inf') formats as 'inf' inside the prediction file name.\n",
    "window = float(\"inf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pandas is needed for pd.concat below; importing it here keeps this cell runnable\n",
    "# on a fresh kernel (it was previously only imported in a later cell).\n",
    "import pandas as pd\n",
    "from alphagen.utils.correlation import batch_pearsonr,batch_spearmanr\n",
    "\n",
    "# For every (n_factors, seed) run: load the saved prediction tensor, compare it with\n",
    "# the evaluated target over the last `data_test.n_days` days, and collect\n",
    "#   * `result`   - one dict per run with mean IC / rank IC and their mean/std ratios\n",
    "#   * `pred_dfs` - the prediction frames, keyed by '<n_factors>_<seed>'\n",
    "# NOTE(review): `target` is not defined in this notebook; presumably it comes from\n",
    "# one of the star imports in the first cell - confirm.\n",
    "device = 'cuda:0'\n",
    "result = []\n",
    "pred_dfs = {}\n",
    "for n_factors in [1,10,20,50,100]:\n",
    "    for seed in range(5):\n",
    "        cur_seed_ic = []\n",
    "        cur_seed_ric = []\n",
    "        all_pred_df_list = []\n",
    "        for train_end in range(2020,2021):\n",
    "            print(n_factors,seed,train_end)\n",
    "            returned = get_data_by_year(\n",
    "                train_start = 2010,train_end=train_end,valid_year=train_end+1,test_year =train_end+2,\n",
    "                instruments=instruments, target=target,freq=freq,\n",
    "            )\n",
    "            data_all, data,data_valid,data_valid_withhead,data_test,data_test_withhead,name = returned\n",
    "\n",
    "            tensor_save_path = f'out/{save_name}_{instruments}_{train_end}_{seed}/pred_{train_end}_{n_factors}_{window}_{seed}.pt'\n",
    "\n",
    "            # map_location remaps the stored tensor onto `device` while loading, so the\n",
    "            # load does not depend on the GPU index the tensor was saved from.\n",
    "            pred = torch.load(tensor_save_path, map_location=device)\n",
    "            tgt = target.evaluate(data_all)\n",
    "\n",
    "            # NaN tensor shaped like the full target; only the trailing test days are\n",
    "            # filled with predictions before it is turned into a dataframe.\n",
    "            pred_full = torch.ones_like(tgt) * torch.nan\n",
    "            pred_full[-data_test.n_days:] = pred\n",
    "            cur_df = data_all.make_dataframe(pred_full)\n",
    "            all_pred_df_list.append(cur_df.unstack().iloc[-data_test.n_days:].stack())\n",
    "\n",
    "            tgt = tgt[-data_test.n_days:].to(device)\n",
    "\n",
    "            # NaN correlations are replaced by 0 before averaging.\n",
    "            ic_s = torch.nan_to_num(batch_pearsonr(pred,tgt),nan=0)\n",
    "            rank_ic_s = torch.nan_to_num(batch_spearmanr(pred,tgt),nan=0)\n",
    "\n",
    "            cur_seed_ic.append(ic_s)\n",
    "            cur_seed_ric.append(rank_ic_s)\n",
    "\n",
    "        pred_dfs[f\"{n_factors}_{seed}\"] = pd.concat(all_pred_df_list,axis=0)\n",
    "        ic = torch.cat(cur_seed_ic)\n",
    "        rank_ic = torch.cat(cur_seed_ric)\n",
    "\n",
    "        ic_mean = ic.mean().item()\n",
    "        rank_ic_mean = rank_ic.mean().item()\n",
    "        ic_std = ic.std().item()\n",
    "        rank_ic_std = rank_ic.std().item()\n",
    "        # These are plain Python floats, so a zero std would raise ZeroDivisionError\n",
    "        # and abort the whole sweep; record NaN for that run instead.\n",
    "        tmp = dict(\n",
    "            seed = seed,\n",
    "            num = n_factors,\n",
    "            ic = ic_mean,\n",
    "            ric = rank_ic_mean,\n",
    "            icir = ic_mean/ic_std if ic_std else float('nan'),\n",
    "            ricir = rank_ic_mean/rank_ic_std if rank_ic_std else float('nan'),\n",
    "        )\n",
    "        result.append(tmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Collapse `result` to one row per (num, seed), then report the mean and std of\n",
    "# every metric across seeds for each value of `num`.\n",
    "per_run = pd.DataFrame(result).groupby(['num','seed']).mean()\n",
    "run_result = per_run.groupby('num').agg(['mean','std'])\n",
    "print(run_result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38n1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
